CoDA — Graph snapshots only when changed (uncertainty-aware)

This notebook runs acquisition → extinction with the uncertainty-aware split rule, but only draws a graph snapshot when the graph changes relative to the previous episode.

Change criterion (fast, deterministic):

  • Any change in transition tensor shape (e.g., new clone added).
  • After aggregating over actions, any change in the thresholded adjacency (same threshold as the plotter).
  • Any change in clone_dict mapping (detects clone re-wiring/merge).

This matches the visual “Graph at episode XX” you use, but avoids redundant frames.

In [12]:
import sys, numpy as np
sys.path.append('/mnt/data')
from coda_trial_by_trial_util import CoDAAgent, CoDAConfig
from spatial_environments import GridEnvRightDownNoSelf, GridEnvRightDownNoCue
from util import generate_dataset, generate_dataset_post_augmentation
In [13]:
import numpy as np

def sanitize_for_plot(env, T, eps=1e-12):
    """
    Prune dead clones from ``env`` before drawing the graph.

    A clone id is dropped from ``env.clone_dict`` when it is not a valid
    index along the first axis of ``T`` or when its combined inbound and
    outbound transition mass in ``T`` does not exceed ``eps``.  Afterwards
    ``env.reverse_clone_dict`` is rebuilt from the surviving entries.

    The function mutates ``env`` in place and returns ``None``.  It is a
    no-op when ``T`` is ``None`` or is not a rank-3 array.
    """
    is_rank3 = T is not None and getattr(T, "ndim", 0) == 3
    if not is_rank3:
        return

    n_states = T.shape[0]

    # Mass leaving each state (summed over actions and successors) plus
    # mass arriving at each state (summed over sources and actions).
    leaving = T.sum(axis=(1, 2))
    arriving = T.sum(axis=(0, 1))
    alive = (leaving + arriving) > eps

    # Iterate over a snapshot of the keys because entries are removed
    # while walking the mapping.
    for clone_id in tuple(env.clone_dict.keys()):
        if clone_id >= n_states or not alive[clone_id]:
            env.clone_dict.pop(clone_id, None)

    # Invert clone -> parent; when several clones share a parent the one
    # visited last in dict order wins.
    rebuilt = {}
    for clone_id, parent in env.clone_dict.items():
        rebuilt[parent] = clone_id
    env.reverse_clone_dict = rebuilt

def make_terminals_absorbing_for_plot(T, terminals):
    """
    Return a copy of ``T`` in which every terminal state has no outgoing mass.

    For each id in ``terminals`` that is smaller than ``T.shape[0]`` the
    whole ``[id, :, :]`` slice of the copy is set to zero; ids at or beyond
    that bound are ignored.  The input array is left untouched.
    """
    out = T.copy()
    n_states = out.shape[0]
    for term in terminals:
        if term >= n_states:
            # Terminal id does not exist in this tensor; nothing to clear.
            continue
        out[term, :, :] = 0.0
    return out
def thresh_adj(T, thr=0.3):
    """Sum ``T`` over axis 1 and return a ``uint8`` mask of entries ``>= thr``."""
    summed = T.sum(axis=1)  # collapse the middle axis
    keep = summed >= thr
    return keep.astype(np.uint8)

def clone_dict_tuple(d):
    """Return the items of ``d`` as a sorted tuple of ``(key, value)`` pairs."""
    pairs = list(d.items())
    pairs.sort()
    return tuple(pairs)

def graph_changed(prev_T, prev_map, curr_T, curr_map, thr=0.3):
    """
    Decide whether the graph differs from the previously recorded one.

    Returns ``True`` when there is no previous tensor, when the tensor
    shapes differ, or when the thresholded adjacencies (via
    ``thresh_adj``) differ; otherwise returns the result of comparing the
    two clone maps for inequality.
    """
    no_baseline = prev_T is None
    if no_baseline or prev_T.shape != curr_T.shape:
        return True

    before = thresh_adj(prev_T, thr)
    after = thresh_adj(curr_T, thr)
    if before.shape != after.shape or (before != after).any():
        return True

    return prev_map != curr_map
In [14]:
# --- Experiment configuration ---
CUE = 5          # state id handed to the environments as the cue state
THRESH = 0.3     # must match env.plot_graph threshold

cfg = CoDAConfig(
    theta_split=0.6,
    theta_merge=0.5,
    n_threshold=8,
    min_presence_episodes=3,
    min_effective_exposure=5.0,
    confidence=0.8,
    count_decay=0.9,
    # trace_decay=0.9,    # makes PC recent
    # retro_decay=0.9     # makes RC recent
)
# cfg.theta_split = 0.85

# Episode budget for each phase and the per-episode step cap.
N_ACQ = 250
N_EXT = 300
MAX_STEPS = 20

# Acquisition environment and the agent that learns on it.
env = GridEnvRightDownNoSelf(
    cue_states=[CUE],
    env_size=(4, 4),
    rewarded_terminal=[15],
)
agent = CoDAAgent(env, cfg)
In [15]:
def thresh_adj(T, thr=0.3):
    """
    Aggregate over actions and threshold to a binary adjacency.

    Parameters
    ----------
    T : array with ``ndim == 3``
        Transition tensor; axis 1 is summed out.
    thr : float
        Entries of the summed matrix that are ``>= thr`` become 1.

    Returns
    -------
    numpy.ndarray of dtype ``uint8`` with the axis-1-summed shape, or
    ``None`` when ``T`` is not a rank-3 array (including ``None``).
    """
    # Validate the rank *before* summing: reducing over axis 1 raises for
    # 0-D/1-D input, so checking afterwards could never report those cases.
    if getattr(T, "ndim", None) != 3:
        return None
    A = T.sum(axis=1)   # [S,S]
    return (A >= thr).astype(np.uint8)

def clone_dict_tuple(d):
    """Freeze the clone mapping into a sorted, comparable tuple of items."""
    # Sorting the (key, value) pairs makes the result independent of
    # dict insertion order.
    ordered_items = sorted(d.items())
    return tuple(ordered_items)
In [16]:
def graph_changed(prev_T, prev_clone_map, curr_T, curr_clone_map, thr=THRESH):
    """
    Report whether the current graph differs from the last recorded one.

    The checks run in order and the first hit returns ``True``:
    no previous tensor, different tensor shapes, an adjacency that could
    not be computed (``thresh_adj`` returned ``None``), different adjacency
    shapes, any differing adjacency entry, and finally a different clone
    map.  ``False`` is returned only when all of them pass.
    """
    # No baseline yet, or the tensor was resized (e.g., clones added).
    if prev_T is None or prev_T.shape != curr_T.shape:
        return True

    adj_before = thresh_adj(prev_T, thr=thr)
    adj_after = thresh_adj(curr_T, thr=thr)

    adjacency_differs = (
        adj_before is None
        or adj_after is None
        or adj_before.shape != adj_after.shape
        or bool(np.any(adj_before != adj_after))
    )
    if adjacency_differs:
        return True

    # Same structure; only a change in the clone mapping is left to detect.
    return bool(prev_clone_map != curr_clone_map)
In [17]:
# --- Run acquisition; only plot when the graph changed ---
with_clones = False      # switched on after the first successful split
prev_T = None            # last plotted transition tensor
prev_map = None          # clone mapping that belongs to prev_T
changed_episodes = []    # episode numbers at which a snapshot was drawn

# Acquisition
for ep in range(1, N_ACQ + 1):
    # Once clones exist, sample the episode through the augmented generator
    # that is given the agent's current transition tensor.
    if with_clones:
        (states, actions) = generate_dataset_post_augmentation(env, agent.get_T(), n_episodes=1, max_steps=MAX_STEPS)[0]
    else:
        (states, actions) = generate_dataset(env, n_episodes=1, max_steps=MAX_STEPS)[0]

    agent.update_with_episode(states, actions)
    if agent.maybe_split():
        with_clones = True

    T_curr = agent.get_T().copy()
    map_curr = clone_dict_tuple(env.clone_dict)

    if graph_changed(prev_T, prev_map, T_curr, map_curr, thr=THRESH):
        sanitize_for_plot(env, T_curr)
        env.plot_graph(T_curr, niter=ep, threshold=THRESH, save=False, savename=f'graph_ep{ep}.png')
        changed_episodes.append(ep)
        prev_T = T_curr
        # sanitize_for_plot may have removed entries from env.clone_dict.
        # Snapshot the mapping *after* that clean-up; keeping the pre-clean
        # tuple would make the next episode look "changed" and draw a
        # redundant frame.
        prev_map = clone_dict_tuple(env.clone_dict)

print("Changed episodes (acquisition):", changed_episodes[:20], "... total:", len(changed_episodes))
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
Changed episodes (acquisition): [1, 3, 4, 5, 6, 7, 9, 12, 15, 19, 20, 21, 22, 23, 37, 44, 54, 57, 80, 84] ... total: 51
In [18]:
# Extinction / degradation
THRESH = 0.3
# prev_T / prev_map are deliberately NOT reset here: they carry over from the
# acquisition loop so the first extinction frame is compared against the last
# acquisition frame.
changed_ext = []    # extinction episode numbers at which a snapshot was drawn

# Swap in the extinction environment, seeded with copies of the clone
# bookkeeping the agent's current environment holds.
env2 = GridEnvRightDownNoCue(cue_states=[CUE], env_size=(4,4), rewarded_terminal=[15])
env2.clone_dict = dict(agent.env.clone_dict)
env2.reverse_clone_dict = dict(agent.env.reverse_clone_dict)
agent.env = env2

for ep in range(N_ACQ + 1, N_ACQ + N_EXT + 1):
    (states, actions) = generate_dataset_post_augmentation(env2, agent.get_T(), n_episodes=1, max_steps=MAX_STEPS)[0]
    agent.update_with_episode(states, actions)
    agent.maybe_merge()

    T_curr = agent.get_T().copy()
    map_curr = clone_dict_tuple(env2.clone_dict)

    if graph_changed(prev_T, prev_map, T_curr, map_curr, thr=THRESH):
        # (optional) clean terminals/clones just for the figure; change
        # detection keeps using the unmodified T_curr.
        T_vis = make_terminals_absorbing_for_plot(T_curr, env2.rewarded_terminals + env2.unrewarded_terminals)
        sanitize_for_plot(env2, T_vis)
        env2.plot_graph(T_vis, niter=ep, threshold=THRESH, save=False, savename=f'graph_ep{ep}.png')

        changed_ext.append(ep)
        prev_T = T_curr
        # sanitize_for_plot may have removed entries from env2.clone_dict.
        # Snapshot the mapping after the clean-up so the next comparison is
        # not tripped by the plotting side effect alone.
        prev_map = clone_dict_tuple(env2.clone_dict)

print("Extinction changed episodes:", changed_ext[:30], "... total:", len(changed_ext))
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
Extinction changed episodes: [251, 255, 256, 257, 263, 265, 270, 282, 291, 296, 297, 302, 303, 321, 322, 327, 330, 332, 338, 341, 345, 352, 354, 364, 366, 367, 369, 379, 381, 382] ... total: 65

You can set save=True in plot_graph to export the changed snapshots as PNGs only for those episodes.

Metrics: KL/JS vs episode, Entropy, and Markovization

In [19]:
# Collect T snapshots during acquisition and extinction.
# If you already recorded them earlier, just reuse those lists.
T_series_acq = []
T_series_ext = []

# Re-run a quick pass to collect snapshots only (no plotting).
#
# Start from a fresh environment and agent. Re-using the agent from the
# plotting cells would continue training a model that has already been through
# acquisition AND extinction (and whose agent.env is the extinction
# environment), so the "acquisition" series would not describe acquisition.
env = GridEnvRightDownNoSelf(cue_states=[CUE], env_size=(4,4), rewarded_terminal=[15])
agent = CoDAAgent(env, cfg)

# Acquisition
with_clones = False
for ep in range(1, N_ACQ + 1):
    if with_clones:
        (states, actions) = generate_dataset_post_augmentation(env, agent.get_T(), n_episodes=1, max_steps=MAX_STEPS)[0]
    else:
        (states, actions) = generate_dataset(env, n_episodes=1, max_steps=MAX_STEPS)[0]
    agent.update_with_episode(states, actions)
    if agent.maybe_split():
        with_clones = True
    # Store a copy so later in-place updates cannot alter the snapshot.
    T_series_acq.append(agent.get_T().copy())

# Extinction: new environment seeded with copies of the clone bookkeeping
# accumulated on the acquisition environment.
env2 = GridEnvRightDownNoCue(cue_states=[CUE], env_size=(4,4), rewarded_terminal=[15])
env2.clone_dict = dict(getattr(env, "clone_dict", {}))
env2.reverse_clone_dict = dict(getattr(env, "reverse_clone_dict", {}))
agent.env = env2

for ep in range(N_ACQ + 1, N_ACQ + N_EXT + 1):
    (states, actions) = generate_dataset_post_augmentation(env2, agent.get_T(), n_episodes=1, max_steps=MAX_STEPS)[0]
    agent.update_with_episode(states, actions)
    agent.maybe_merge()
    T_series_ext.append(agent.get_T().copy())
In [20]:
# Compute metrics using the module we prepared
from coda_metrics import kl_over_time, entropy_over_time, markovization_score, ref_empirical_from_rollouts, greedy_right_down_policy
import numpy as np

def ref_builder_factory(env, policy_fn, nroll=300, max_steps=20):
    """
    Build a one-argument callable that produces an empirical reference.

    The returned function accepts a learned tensor but does not use it: on
    every call it invokes ``ref_empirical_from_rollouts`` with ``env``,
    ``policy_fn``, ``n_episodes=nroll`` and ``max_steps=max_steps`` and
    returns that result.
    """
    rollout_kwargs = {"n_episodes": nroll, "max_steps": max_steps}

    def _make_ref(T_learned):
        # T_learned is accepted but ignored by this reference builder.
        return ref_empirical_from_rollouts(env, policy_fn, **rollout_kwargs)

    return _make_ref

# Build episode-wise empirical references
# Each factory call returns a closure that ignores its argument and runs
# ref_empirical_from_rollouts (300 rollouts, 20 steps max) on the given env.
# NOTE(review): the KL and JS passes below each call the reference function
# themselves, so they presumably compare against independently sampled
# references -- confirm against kl_over_time in coda_metrics.
ref_fn_acq = ref_builder_factory(env,  greedy_right_down_policy, nroll=300, max_steps=20)
ref_fn_ext = ref_builder_factory(env2, greedy_right_down_policy, nroll=300, max_steps=20)

# Acquisition metrics, one value per stored snapshot.
# use_js presumably switches kl_over_time from KL to Jensen-Shannon -- TODO confirm.
KL_acq = kl_over_time(T_series_acq, ref_fn_acq, use_js=False)
JS_acq = kl_over_time(T_series_acq, ref_fn_acq, use_js=True)
H_acq  = entropy_over_time(T_series_acq)
MS_acq = np.array([markovization_score(T) for T in T_series_acq])

# Extinction metrics, computed the same way against the extinction reference.
KL_ext = kl_over_time(T_series_ext, ref_fn_ext, use_js=False)
JS_ext = kl_over_time(T_series_ext, ref_fn_ext, use_js=True)
H_ext  = entropy_over_time(T_series_ext)
MS_ext = np.array([markovization_score(T) for T in T_series_ext])
In [21]:
# Plot (one metric per figure)
import matplotlib.pyplot as plt
import numpy as np

def _offset_plot(ax, y1, y2, label1, label2):
    ax.plot(y1, label=label1)
    off = len(y1)
    ax.plot(off + np.arange(len(y2)), y2, label=label2)
    ax.legend()
    ax.set_xlabel("episode")

# One figure per metric: (title, acquisition series, extinction series, y label).
_metric_figures = [
    ("KL (learned || empirical reference)", KL_acq, KL_ext, "KL"),
    ("JS distance", JS_acq, JS_ext, "JS"),
    ("Avg next-state entropy H(S'|S)", H_acq, H_ext, "nats"),
    ("Markovization score (1 - normalized H)", MS_acq, MS_ext, "[0,1]"),
]

for _title, _acq_series, _ext_series, _ylabel in _metric_figures:
    fig, ax = plt.subplots()
    ax.set_title(_title)
    _offset_plot(ax, _acq_series, _ext_series, "acq", "ext")
    ax.set_ylabel(_ylabel)
    plt.show()
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

Plots (separate panels with mean ± SE shading)

In [22]:
import numpy as np
import matplotlib.pyplot as plt

def _pad_runs(runs):
    L = max(len(r) for r in runs)
    out = np.full((len(runs), L), np.nan, dtype=float)
    for i, r in enumerate(runs):
        out[i, :len(r)] = r
    return out

def _plot_with_band(ax, runs, title, ylabel):
    M = _pad_runs(runs) if isinstance(runs, (list, tuple)) and len(runs)>0 and isinstance(runs[0], (list, np.ndarray)) else np.atleast_2d(runs)
    mean = np.nanmean(M, axis=0)
    se   = np.nanstd(M, axis=0, ddof=1) / np.sqrt(max(1, M.shape[0]))
    x = np.arange(len(mean))
    ax.plot(x, mean, lw=2.0, label="mean")
    ax.fill_between(x, mean - se, mean + se, alpha=0.2, label="±1 SE")
    ax.set_title(title)
    ax.set_xlabel("Episode")
    ax.set_ylabel(ylabel)
    ax.legend()

def _as_runs(series):
    """Return ``series`` unchanged if it is a list/tuple, else wrap it in a one-element list."""
    return series if isinstance(series, (list, tuple)) else [series]


# Normalise every metric so a bare single-run array is passed on as one run.
KL_acq_runs = _as_runs(KL_acq)
JS_acq_runs = _as_runs(JS_acq)
H_acq_runs = _as_runs(H_acq)
MS_acq_runs = _as_runs(MS_acq)

KL_ext_runs = _as_runs(KL_ext)
JS_ext_runs = _as_runs(JS_ext)
H_ext_runs = _as_runs(H_ext)
MS_ext_runs = _as_runs(MS_ext)

# Panel specs in row-major order of the 2x2 grid: (runs, title, y label).
_acq_panels = [
    (KL_acq_runs, "KL (acquisition)", "KL (nats)"),
    (JS_acq_runs, "JS (acquisition)", "JS"),
    (H_acq_runs, "Avg H(S'|S) (acquisition)", "nats"),
    (MS_acq_runs, "Markovization (acquisition)", "[0,1]"),
]
_ext_panels = [
    (KL_ext_runs, "KL (extinction)", "KL (nats)"),
    (JS_ext_runs, "JS (extinction)", "JS"),
    (H_ext_runs, "Avg H(S'|S) (extinction)", "nats"),
    (MS_ext_runs, "Markovization (extinction)", "[0,1]"),
]

# Acquisition-only figure
fig, axes = plt.subplots(2, 2, figsize=(10, 6), constrained_layout=True)
for _panel_ax, (_runs, _title, _ylabel) in zip(axes.flat, _acq_panels):
    _plot_with_band(_panel_ax, _runs, _title, _ylabel)
plt.show()

# Extinction-only figure
fig, axes = plt.subplots(2, 2, figsize=(10, 6), constrained_layout=True)
for _panel_ax, (_runs, _title, _ylabel) in zip(axes.flat, _ext_panels):
    _plot_with_band(_panel_ax, _runs, _title, _ylabel)
plt.show()
No description has been provided for this image
No description has been provided for this image